/*
* GradientGraphCombFuncNode.java 
* 
* Copyright (C) 2009 Aalborg University
*
* contact:
* jaeger@cs.aau.dk   http://www.cs.aau.dk/~jaeger/Primula.html
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License
* as published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
*/

package RBNLearning;

import java.util.*;
import java.io.*;
import RBNpackage.*;
import RBNgui.*;
import RBNExceptions.*;
import RBNutilities.*;
import RBNinference.*;


/**
 * Gradient graph node for a probability formula that is a combination function
 * (NOR, MEAN, INVSUM or ESUM) applied to grounded sub-formulas.
 *
 * Sub-formulas that evaluate to a constant are stored by value in
 * {@link #valuesOfSubPFs}; all other sub-formulas become child nodes of this node.
 * Value, bounds and gradient of this node are computed from these constants and
 * the corresponding quantities of the children.
 */
public class GradientGraphCombFuncNode extends GradientGraphProbFormNode{

	/** The combination function used (compared against the ProbFormCombFunc type constants). */
	int typeOfComb;

	/**
	 * Contains the values of those ProbForms in the argument
	 * of the combination function that do not depend on
	 * unobserved probabilistic atoms or unknown parameters.
	 */
	double[] valuesOfSubPFs;

	/**
	 * Builds the node for the combination function formula pf.
	 *
	 * For all probability formulas in the argument of the combination function, and all
	 * tuples satisfying its constraint: check whether the formula evaluates for this
	 * tuple to a constant. If yes, the value is recorded in valuesOfSubPFs; if no, a
	 * child node is constructed for the grounded formula and linked to this node.
	 *
	 * @param gg the gradient graph this node belongs to
	 * @param pf the formula of this node; must be a ProbFormCombFunc
	 * @param allnodes passed through unchanged to the construction of child nodes
	 * @param A the relational structure used for grounding and evaluation
	 * @param I the data used for evaluation
	 * @param inputcaseno passed through unchanged to the construction of child nodes
	 * @param observcaseno passed through unchanged to the construction of child nodes
	 * @throws RBNCompatibilityException propagated from grounding/evaluation/child construction
	 */
	public GradientGraphCombFuncNode(GradientGraph gg,
			ProbForm pf, 
			Hashtable allnodes,
			RelStruc A,
			OneStrucData I,
			int inputcaseno,
			int observcaseno)
	throws RBNCompatibilityException
	{
		super(gg,pf,A,I);

		DoubleVector vals = new DoubleVector();

		ProbFormCombFunc pfcomb = (ProbFormCombFunc)pf;
		typeOfComb = pfcomb.myCombInt();

		int[][] subslist = pfcomb.tuplesSatisfyingCConstr(A, new String[0], new int[0]);

		children = new Vector<GradientGraphProbFormNode>();

		for (int i=0; i<pfcomb.numPFargs(); i++)
		{
			ProbForm nextsubpf = pfcomb.probformAt(i);
			for (int j=0; j<subslist.length; j++)
			{
				ProbForm groundnextsubpf = nextsubpf.substitute(pfcomb.quantvars(),subslist[j]);
				double evalOfSubPF = groundnextsubpf.evaluate(A, I, new String[0], new int[0] , false);

				/* -1 is the marker for "does not evaluate to a constant" */
				if (evalOfSubPF == -1){
					GradientGraphProbFormNode constructedchild = GradientGraphProbFormNode.constructGGPFN(gg,
							groundnextsubpf,
							allnodes, 
							A, 
							I,
							inputcaseno,
							observcaseno,
							false,
							"");
					children.add(constructedchild);
					constructedchild.addToParents(this);
				}
				else
					vals.add(evalOfSubPF);

			} /* for (int j=0; j<subslist.length; j++) */
		}/* for (int i=0; i<pfcomb.numPFargs(); i++) */
		valuesOfSubPFs = vals.asArray();
	}


	/**
	 * Allocates an argument array for the combination function with one slot per
	 * constant sub-formula value followed by one slot per child. The constant
	 * values are already filled in; the child slots (starting at index
	 * valuesOfSubPFs.length) are left for the caller to fill.
	 */
	private double[] argsPrefilledWithConstants(){
		double[] args = new double[valuesOfSubPFs.length + children.size()];
		System.arraycopy(valuesOfSubPFs, 0, args, 0, valuesOfSubPFs.length);
		return args;
	}

	/**
	 * Applies the combination function of this node to args by delegating to the
	 * gradient graph. Returns 0 if typeOfComb is none of the handled types.
	 */
	private double computeCombFunc(double[] args){
		double result = 0;
		switch (typeOfComb){
		case ProbFormCombFunc.NOR:
			result = thisgg.computeCombFunc(Primula.CF_NOR,args);
			break;
		case ProbFormCombFunc.MEAN: 
			result = thisgg.computeCombFunc(Primula.CF_MEAN,args);
			break;
		case ProbFormCombFunc.INVSUM:
			result = thisgg.computeCombFunc(Primula.CF_INVSUM,args);
			break;
		case ProbFormCombFunc.ESUM:
			result = thisgg.computeCombFunc(Primula.CF_ESUM,args);
			break;
		}
		return result;
	}

	/**
	 * Returns the value of this node: the combination function applied to the
	 * constant sub-formula values and the values of all children. The result is
	 * stored in value and returned directly on later calls while value is non-null.
	 */
	public double evaluate(){
		if (value != null)
			return value;

		/* Construct an argument array for the combination function: */
		double[] args = argsPrefilledWithConstants();
		for (int i=0;i<children.size();i++)
			args[i+valuesOfSubPFs.length]= children.elementAt(i).evaluate();
		double result = computeCombFunc(args);
		value = result;
		/* NaN never compares equal with ==, so the test has to be Double.isNaN */
		if (Double.isNaN(result))
			System.out.println("result in evaluate for comb.func " + this.name() +" : " + result);
		return result;
	}

	/**
	 * Computes lower and upper bounds for the value of this node from the bounds of
	 * its children, provided bounds[0] still holds -1 (presumably the marker for
	 * "not yet computed" -- confirm against the superclass).
	 */
	public void evaluateBounds(){
		if (bounds[0]== -1){
			/* First set bounds at children: */
			for (int i=0;i<children.size();i++)
				children.elementAt(i).evaluateBounds();
			/* Determine arrays of lower and upper bounds for sub-formulas */
			double[] lowargs = argsPrefilledWithConstants();
			double[] uppargs = argsPrefilledWithConstants();
			for (int i=0;i<children.size();i++){
				lowargs[i+valuesOfSubPFs.length]= children.elementAt(i).lowerBound();
				uppargs[i+valuesOfSubPFs.length]= children.elementAt(i).upperBound();
			}
			/* NOR and MEAN are monotone in their arguments, INVSUM and ESUM anti-monotone */
			switch (typeOfComb){
			case ProbFormCombFunc.NOR:
				bounds[0] = thisgg.computeCombFunc(Primula.CF_NOR,lowargs);
				bounds[1] = thisgg.computeCombFunc(Primula.CF_NOR,uppargs);
				break;
			case ProbFormCombFunc.MEAN: 
				bounds[0] = thisgg.computeCombFunc(Primula.CF_MEAN,lowargs);
				bounds[1] = thisgg.computeCombFunc(Primula.CF_MEAN,uppargs);
				break;
			case ProbFormCombFunc.INVSUM:		
				bounds[0] = thisgg.computeCombFunc(Primula.CF_INVSUM,uppargs);
				bounds[1] = thisgg.computeCombFunc(Primula.CF_INVSUM,lowargs);
				break;
			case ProbFormCombFunc.ESUM:
				/* Anti-monotone, consistent with computeDerivESUM whose derivative is
				 * -value * (sum of child derivatives). Without this case the bounds
				 * of ESUM nodes were never set. */
				bounds[0] = thisgg.computeCombFunc(Primula.CF_ESUM,uppargs);
				bounds[1] = thisgg.computeCombFunc(Primula.CF_ESUM,lowargs);
				break;
			}
		}
	}

	/**
	 * Returns the partial derivative of the value of this node with respect to the
	 * parameter with index param. A previously stored gradient[param] is returned
	 * directly; if this node does not depend on the parameter, 0 is returned
	 * without storing anything. Otherwise the derivative is computed according to
	 * the type of the combination function and stored in gradient[param].
	 */
	public double evaluateGrad(int param){
		if (gradient[param] != null)
			return gradient[param];
		if (!dependsOnParam[param])
			return 0.0;
		double result = 0;
		switch (typeOfComb){
		case ProbFormCombFunc.NOR:
			result = computeDerivNOR(param);
			break;

		case ProbFormCombFunc.MEAN:
			result = computeDerivMEAN(param);
			break;

		case ProbFormCombFunc.INVSUM:
			result = computeDerivINVSUM(param);
			break;

		case ProbFormCombFunc.ESUM:
			result = computeDerivESUM(param);
			break;

		}
		gradient[param] = result;
		return result;
	}



	/**
	 * Derivative for NOR. Returns 0 when the product of all (1-F_i) is 0, which
	 * also avoids a division by zero in the sum below.
	 */
	private double computeDerivNOR(int param){
		double result = 0;
		/* First compute \prod (1-Fi) over all subformulas */
		double factor =1;
		for (int i=0;i<valuesOfSubPFs.length;i++)
			factor = factor*(1-valuesOfSubPFs[i]);
		for (int i=0;i<children.size();i++)
			factor = factor*(1-children.elementAt(i).evaluate());

		if (factor == 0)
			return 0.0;

		/* Now compute the partial derivative as
		 *
		 * \sum_{F_i\in fthetalist} (factor/(1-F_i))*(F_i')
		 */
		for (int i=0;i<children.size();i++){
			if (children.elementAt(i).dependsOn(param))
				result = result + (factor/(1-children.elementAt(i).evaluate()))*(children.elementAt(i).evaluateGrad(param));
		}

		return result;
	}


	/**
	 * Derivative for MEAN: the sum of the child derivatives divided by the total
	 * number of arguments (constants and children).
	 */
	private double computeDerivMEAN(int param ){
		double result = 0;
		for (int i=0;i<children.size();i++)
			result = result + children.elementAt(i).evaluateGrad(param);
		result = result/(valuesOfSubPFs.length + children.size());
		return result;
	}

	/**
	 * Computes the INVSUM combination function on the constants and the current
	 * values of the children, bypassing the value stored in this node.
	 */
	private double computeValueINVSUM(){
		double[] args = argsPrefilledWithConstants();
		for (int i=0;i<children.size();i++)
			args[i+valuesOfSubPFs.length]= children.elementAt(i).evaluate();
		return thisgg.computeCombFunc(Primula.CF_INVSUM,args);
	}

	/**
	 * Derivative for INVSUM.
	 *
	 * NOTE(review): assumes INVSUM yields 1/S for the argument sum S, clamped to 1
	 * (the implementation lives in the gradient graph and is not visible here).
	 * Under that assumption a value of exactly 1 is treated as the clamped,
	 * constant case with derivative 0, and otherwise
	 *
	 *   d(1/S)/dtheta = -S'/S^2 = -val^2 * S'
	 *
	 * with val = 1/S. The previous factor -val^(-2) equals -S^2 and was incorrect.
	 */
	private double computeDerivINVSUM(int param ){
		double val = this.evaluate();
		if (val == 1.0)
			return 0;
		double derivsum = 0;
		for (int i=0;i<children.size();i++)
			derivsum = derivsum + children.elementAt(i).evaluateGrad(param);
		return -(val*val)*derivsum;
	}

	/**
	 * Derivative for ESUM: minus the value of this node times the sum of the
	 * child derivatives.
	 */
	private double computeDerivESUM(int param ){
		double val = this.evaluate();

		double derivsum = 0;
		for (int i=0;i<children.size();i++)
			derivsum = derivsum + children.elementAt(i).evaluateGrad(param);

		return -val*derivsum;
	}



}
